import os
import threading
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow_model_optimization.sparsity import keras as sparsity

from tool import common


def pruning_thread(model_path='E:\\workspace python\\tf2\\LWHD2\\src\\model\\tic_0305.h5',
                   output_path='E:\\workspace python\\tf2\\LWHD2\\src\\model\\pruned_0305.h5'):
    """Prune the 'tic' classifier in a background thread.

    Loads the Keras model at ``model_path``, wraps it with polynomial-decay
    magnitude pruning (50% -> 90% sparsity) and fine-tunes it with SGD on
    the tic dataset, then strips the pruning wrappers and saves the slim
    model to ``output_path``.

    Args:
        model_path: Path of the .h5 model to prune. Defaults to the
            original hard-coded location.
        output_path: Where the stripped, pruned .h5 model is written.

    Returns:
        threading.Thread: the started worker thread, so callers may
        ``join()`` it (previously nothing was returned).
    """
    def pruning():
        model = tf.keras.models.load_model(model_path)

        epoch_num = 12
        # Total pruning steps: ceil(num_samples / batch_size) * epochs,
        # assuming 6845*4 training samples at batch size 64 — the schedule
        # must reach final sparsity by the last training step.
        end_step = np.ceil(1.0 * 6845 * 4 / 64).astype(np.int32) * epoch_num
        print(end_step)

        new_pruning_params = {
            'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,
                                                         final_sparsity=0.90,
                                                         begin_step=0,
                                                         end_step=end_step,
                                                         frequency=100)
        }

        new_pruned_model = sparsity.prune_low_magnitude(model, **new_pruning_params)
        new_pruned_model.summary()

        learning_rate = 1e-3 * 0.8

        def scheduler(epoch):
            # Step decay: full LR for the first 40% of epochs, 10% of it
            # for the next 30%, then 2% for the remainder.
            if epoch < epoch_num * 0.4:
                return learning_rate
            if epoch < epoch_num * 0.7:
                return learning_rate * 0.1
            return learning_rate * 0.02

        # 'lr' is a deprecated alias in tf.keras optimizers; the documented
        # keyword is 'learning_rate'.
        sgd = optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)
        change_lr = LearningRateScheduler(scheduler)
        new_pruned_model.compile(sgd, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

        logdir = os.path.join("tflog")
        print('Writing training logs to ' + logdir)

        callbacks = [
            # UpdatePruningStep is mandatory: it advances the pruning step
            # counter the schedule depends on.
            sparsity.UpdatePruningStep(),
            sparsity.PruningSummaries(log_dir=logdir, profile_batch=0),
            change_lr
        ]

        new_pruned_model.fit(common.dataProcess_tic_type.get_dataset_train(),
                             epochs=epoch_num,
                             verbose=1,
                             callbacks=callbacks)

        score = new_pruned_model.evaluate(common.dataProcess_tic_type.get_dataset_test())

        print('Test loss:', score[0])
        print('Test accuracy:', score[1])

        # Remove the pruning wrappers so the saved model is a plain Keras
        # model (weights stay sparse).
        final_model = sparsity.strip_pruning(new_pruned_model)
        final_model.summary()

        tf.keras.models.save_model(final_model, output_path,
                                   include_optimizer=False)

        print("Size of the pruned model before compression: %.2f Mb"
              % (os.path.getsize(output_path) / float(2 ** 20)))

    th = threading.Thread(target=pruning)
    th.start()
    return th


def pruning_axle_thread(model_path='E:\\workspace python\\tf2\\LWHD2\\src\\model\\axle_resnet_auto.h5',
                        output_path='E:\\workspace python\\tf2\\LWHD2\\src\\model\\pruned_axle_resnet_auto.h5'):
    """Prune the axle ResNet classifier in a background thread.

    Loads the Keras model at ``model_path``, applies polynomial-decay
    magnitude pruning (50% -> 95% sparsity) while fine-tuning with Adam on
    the axle dataset, strips the pruning wrappers, and saves the result to
    ``output_path``.

    Args:
        model_path: Path of the .h5 model to prune. Defaults to the
            original hard-coded location.
        output_path: Where the stripped, pruned .h5 model is written.

    Returns:
        threading.Thread: the started worker thread, so callers may
        ``join()`` it (previously nothing was returned).
    """
    def pruning():
        model = tf.keras.models.load_model(model_path)

        train_data_size = 10445
        epoch_num = 10
        # Total pruning steps: ceil(samples / batch_size) * epochs, so the
        # schedule reaches final sparsity on the last training step.
        end_step = np.ceil(train_data_size / 64).astype(np.int32) * epoch_num
        print(end_step)

        new_pruning_params = {
            'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,
                                                         final_sparsity=0.95,
                                                         begin_step=0,
                                                         end_step=end_step)
        }

        new_pruned_model = sparsity.prune_low_magnitude(model, **new_pruning_params)

        # NOTE(review): this variant fine-tunes with Adam at its default
        # learning rate; the SGD + LR-schedule setup used for the tic model
        # was abandoned here (previously left as dead commented-out code).
        new_pruned_model.compile(optimizer='adam', loss="sparse_categorical_crossentropy", metrics=["accuracy"])

        # Timestamped log dir so successive runs don't overwrite each other.
        logdir = os.path.join("tflog/pruningalxe_{}".format(time.strftime("%m%d%H%M", time.localtime())))
        print('Writing training logs to ' + logdir)

        callbacks = [
            # UpdatePruningStep is mandatory: it advances the pruning step
            # counter the schedule depends on.
            sparsity.UpdatePruningStep(),
            sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)
        ]

        new_pruned_model.fit(common.dataProcess_axle_type.get_dataset_train(),
                             epochs=epoch_num,
                             verbose=1,
                             callbacks=callbacks)

        score = new_pruned_model.evaluate(common.dataProcess_axle_type.get_dataset_test())

        print('Test loss:', score[0])
        print('Test accuracy:', score[1])

        # Remove the pruning wrappers so the saved model is a plain Keras
        # model (weights stay sparse).
        final_model = sparsity.strip_pruning(new_pruned_model)

        tf.keras.models.save_model(final_model, output_path,
                                   include_optimizer=False)

        print("Size of the pruned model before compression: %.2f Mb"
              % (os.path.getsize(output_path) / float(2 ** 20)))

    th = threading.Thread(target=pruning)
    th.start()
    return th
